import numpy as np
import pandas as pd
import statsmodels.api as sm
from sklearn.metrics import r2_score
import shap
import matplotlib.pyplot as plt

# A GBT model is assumed to be pre-trained and available as `model`


def get_shap_values(model, X):
    """
    Build a SHAP TreeExplainer for `model` and return the interaction values
    it computes for `X`.
    TODO: May need to adjust for different model types.

    Parameters:
        model: Trained tree-based model (e.g., XGBoost, LightGBM, etc.)
        X: Input data (numpy array or pandas DataFrame)

    Returns:
        Whatever `TreeExplainer.shap_interaction_values(X)` returns, unmodified
    """
    # Explainer construction and evaluation are chained; nothing is cached.
    return shap.TreeExplainer(model).shap_interaction_values(X)


def get_feature_names(X):
    """
    Resolve the feature names for X.

    If X exposes a `columns` attribute (e.g., a pandas DataFrame), those column
    labels are used. Otherwise default names "x_0", "x_1", ... are generated,
    one per entry along X.shape[1].

    Parameters:
        X: Input data (numpy array or pandas DataFrame)

    Returns:
        feature_names: List of feature names
        n_features: Number of features
    """
    # Guard clause: objects without column labels get generated names.
    if not hasattr(X, 'columns'):
        column_count = X.shape[1]
        generated = [f"x_{idx}" for idx in range(column_count)]
        return generated, len(generated)

    labels = list(X.columns)
    return labels, len(labels)

def get_mean_interactions(shap_interaction_values, feature_names):
    """
    Calculate mean SHAP interaction values for all feature pairs.
    This serves as a baseline for measuring interaction strength.

    For every unordered pair (i, j) with i < j, the (i, j) and (j, i) entries
    are added per sample and the sum is then averaged over all samples.

    Parameters:
        shap_interaction_values: SHAP interaction values array, indexed as
            [sample, feature, feature]
        feature_names: List of feature names; its length must equal
            shap_interaction_values.shape[1]

    Returns:
        pandas DataFrame with one row per feature pair and the columns
        i, j, feature_i, feature_j, mean_interaction, mean_abs_interaction.
        `mean_abs_interaction` is the absolute value of `mean_interaction`
        (i.e. |mean|, not the mean of absolute values).
        Rows are sorted by `mean_abs_interaction` in descending order.
        When there are fewer than two features, an empty DataFrame with the
        same columns is returned.

    Raises:
        ValueError: if the number of feature names does not match
            shap_interaction_values.shape[1]
    """
    columns = [
        'i',
        'j',
        'feature_i',
        'feature_j',
        'mean_interaction',
        'mean_abs_interaction',
    ]

    # Explicit raise instead of `assert`, so the check survives `python -O`.
    n_features = len(feature_names)
    if shap_interaction_values.shape[1] != n_features:
        raise ValueError("Mismatch in number of features")

    results = []

    for i in range(n_features):
        for j in range(i + 1, n_features):
            # The pair's total interaction is split across the (i, j) and
            # (j, i) entries, so both halves are added back together.
            combined_interactions = (
                shap_interaction_values[:, i, j]
                + shap_interaction_values[:, j, i]
            )

            # Simple (signed) mean over samples
            mean_interaction = np.mean(combined_interactions)

            results.append({
                'i': i,
                'j': j,
                'feature_i': feature_names[i],
                'feature_j': feature_names[j],
                'mean_interaction': mean_interaction,
                'mean_abs_interaction': np.abs(mean_interaction),
            })

    # Passing `columns` keeps the schema stable even when `results` is empty;
    # without it the sort below would fail on a column-less frame.
    results_df = pd.DataFrame(results, columns=columns)
    results_df = results_df.sort_values(
        'mean_abs_interaction', ascending=False
    ).reset_index(drop=True)

    return results_df


def get_shap_mean_baseline(model, X):
    """
    Rank feature-pair interactions for `model` on `X` by their mean SHAP
    interaction value.

    The work is delegated to three helpers: get_shap_values computes the
    interaction values, get_feature_names resolves the names, and
    get_mean_interactions builds the ranking table.

    Parameters:
        model: trained tree-based model (used by get_shap_values)
        X: numpy array or pandas DataFrame of input features

    Returns:
        pandas DataFrame with interaction rankings based on mean SHAP values
    """
    interactions = get_shap_values(model, X)

    # A list result (the multiclass case) is reduced to its first element;
    # any other result is used as-is.
    interactions = interactions[0] if isinstance(interactions, list) else interactions

    # Only the names are needed here; the feature count is discarded.
    names = get_feature_names(X)[0]
    return get_mean_interactions(interactions, names)

